{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 2.6 損失関数の実装\n",
    "\n",
    "- 本ファイルでは、SSDネットワークモデルの損失関数を作成します。\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 2.6 学習目標\n",
    "\n",
    "1.\tjaccard係数を用いたmatch関数の動作を理解する\n",
    "2.\tHard Negative Miningを理解する\n",
    "3.\t2種類の損失関数（SmoothL1Loss関数、交差エントロピー誤差関数）の働きを理解する\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 事前準備\n",
    "\n",
    "\n",
    "フォルダ「utils」のmatch.pyを使用します"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# パッケージのimport\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "# フォルダ「utils」にある関数matchを記述したmatch.pyからimport\n",
    "from utils.match import match"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 損失関数クラスMultiBoxLossを実装する"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MultiBoxLoss(nn.Module):\n",
    "    \"\"\"SSDの損失関数のクラスです。\"\"\"\n",
    "\n",
    "    def __init__(self, jaccard_thresh=0.5, neg_pos=3, device='cpu'):\n",
    "        super(MultiBoxLoss, self).__init__()\n",
    "        self.jaccard_thresh = jaccard_thresh  # 0.5 関数matchのjaccard係数の閾値\n",
    "        self.negpos_ratio = neg_pos  # 3:1 Hard Negative Miningの負と正の比率\n",
    "        self.device = device  # CPUとGPUのいずれで計算するのか\n",
    "\n",
    "    def forward(self, predictions, targets):\n",
    "        \"\"\"\n",
    "        損失関数の計算。\n",
    "\n",
    "        Parameters\n",
    "        ----------\n",
    "        predictions : SSD netの訓練時の出力(tuple)\n",
    "            (loc=torch.Size([num_batch, 8732, 4]), conf=torch.Size([num_batch, 8732, 21]), dbox_list=torch.Size [8732,4])。\n",
    "\n",
    "        targets : [num_batch, num_objs, 5]\n",
    "            5は正解のアノテーション情報[xmin, ymin, xmax, ymax, label_ind]を示す\n",
    "\n",
    "        Returns\n",
    "        -------\n",
    "        loss_l : テンソル\n",
    "            locの損失の値\n",
    "        loss_c : テンソル\n",
    "            confの損失の値\n",
    "\n",
    "        \"\"\"\n",
    "\n",
    "        # SSDモデルの出力がタプルになっているので、個々にばらす\n",
    "        loc_data, conf_data, dbox_list = predictions\n",
    "\n",
    "        # 要素数を把握\n",
    "        num_batch = loc_data.size(0)  # ミニバッチのサイズ\n",
    "        num_dbox = loc_data.size(1)  # DBoxの数 = 8732\n",
    "        num_classes = conf_data.size(2)  # クラス数 = 21\n",
    "\n",
    "        # 損失の計算に使用するものを格納する変数を作成\n",
    "        # conf_t_label：各DBoxに一番近い正解のBBoxのラベルを格納させる\n",
    "        # loc_t:各DBoxに一番近い正解のBBoxの位置情報を格納させる\n",
    "        conf_t_label = torch.LongTensor(num_batch, num_dbox).to(self.device)\n",
    "        loc_t = torch.Tensor(num_batch, num_dbox, 4).to(self.device)\n",
    "\n",
    "        # loc_tとconf_t_labelに、\n",
    "        # DBoxと正解アノテーションtargetsをmatchさせた結果を上書きする\n",
    "        for idx in range(num_batch):  # ミニバッチでループ\n",
    "\n",
    "            # 現在のミニバッチの正解アノテーションのBBoxとラベルを取得\n",
    "            truths = targets[idx][:, :-1].to(self.device)  # BBox\n",
    "            # ラベル [物体1のラベル, 物体2のラベル, …]\n",
    "            labels = targets[idx][:, -1].to(self.device)\n",
    "\n",
    "            # デフォルトボックスを新たな変数で用意\n",
    "            dbox = dbox_list.to(self.device)\n",
    "\n",
    "            # 関数matchを実行し、loc_tとconf_t_labelの内容を更新する\n",
    "            # （詳細）\n",
    "            # loc_t:各DBoxに一番近い正解のBBoxの位置情報が上書きされる\n",
    "            # conf_t_label：各DBoxに一番近いBBoxのラベルが上書きされる\n",
    "            # ただし、一番近いBBoxとのjaccard overlapが0.5より小さい場合は\n",
    "            # 正解BBoxのラベルconf_t_labelは背景クラスの0とする\n",
    "            variance = [0.1, 0.2]\n",
    "            # このvarianceはDBoxからBBoxに補正計算する際に使用する式の係数です\n",
    "            match(self.jaccard_thresh, truths, dbox,\n",
    "                  variance, labels, loc_t, conf_t_label, idx)\n",
    "\n",
    "        # ----------\n",
    "        # 位置の損失：loss_lを計算\n",
    "        # Smooth L1関数で損失を計算する。ただし、物体を発見したDBoxのオフセットのみを計算する\n",
    "        # ----------\n",
    "        # 物体を検出したBBoxを取り出すマスクを作成\n",
    "        pos_mask = conf_t_label > 0  # torch.Size([num_batch, 8732])\n",
    "\n",
    "        # pos_maskをloc_dataのサイズに変形\n",
    "        pos_idx = pos_mask.unsqueeze(pos_mask.dim()).expand_as(loc_data)\n",
    "\n",
    "        # Positive DBoxのloc_dataと、教師データloc_tを取得\n",
    "        loc_p = loc_data[pos_idx].view(-1, 4)\n",
    "        loc_t = loc_t[pos_idx].view(-1, 4)\n",
    "\n",
    "        # 物体を発見したPositive DBoxのオフセット情報loc_tの損失（誤差）を計算\n",
    "        loss_l = F.smooth_l1_loss(loc_p, loc_t, reduction='sum')\n",
    "\n",
    "        # ----------\n",
    "        # クラス予測の損失：loss_cを計算\n",
    "        # 交差エントロピー誤差関数で損失を計算する。ただし、背景クラスが正解であるDBoxが圧倒的に多いので、\n",
    "        # Hard Negative Miningを実施し、物体発見DBoxと背景クラスDBoxの比が1:3になるようにする。\n",
    "        # そこで背景クラスDBoxと予想したもののうち、損失が小さいものは、クラス予測の損失から除く\n",
    "        # ----------\n",
    "        batch_conf = conf_data.view(-1, num_classes)\n",
    "\n",
    "        # クラス予測の損失を関数を計算(reduction='none'にして、和をとらず、次元をつぶさない)\n",
    "        loss_c = F.cross_entropy(\n",
    "            batch_conf, conf_t_label.view(-1), reduction='none')\n",
    "\n",
    "        # -----------------\n",
    "        # これからNegative DBoxのうち、Hard Negative Miningで抽出するものを求めるマスクを作成します\n",
    "        # -----------------\n",
    "\n",
    "        # 物体発見したPositive DBoxの損失を0にする\n",
    "        # （注意）物体はlabelが1以上になっている。ラベル0は背景。\n",
    "        num_pos = pos_mask.long().sum(1, keepdim=True)  # ミニバッチごとの物体クラス予測の数\n",
    "        loss_c = loss_c.view(num_batch, -1)  # torch.Size([num_batch, 8732])\n",
    "        loss_c[pos_mask] = 0  # 物体を発見したDBoxは損失0とする\n",
    "\n",
    "        # Hard Negative Miningを実施する\n",
    "        # 各DBoxの損失の大きさloss_cの順位であるidx_rankを求める\n",
    "        _, loss_idx = loss_c.sort(1, descending=True)\n",
    "        _, idx_rank = loss_idx.sort(1)\n",
    "\n",
    "        # （注釈）\n",
    "        # 実装コードが特殊で直感的ではないです。\n",
    "        # 上記2行は、要は各DBoxに対して、損失の大きさが何番目なのかの情報を\n",
    "        # 変数idx_rankとして高速に取得したいというコードです。\n",
    "        #\n",
    "        # DBOXの損失値の大きい方から降順に並べ、DBoxの降順のindexをloss_idxに格納。\n",
    "        # 損失の大きさloss_cの順位であるidx_rankを求める。\n",
    "        # ここで、\n",
    "        # 降順になった配列indexであるloss_idxを、0から8732まで昇順に並べ直すためには、\n",
    "        # 何番目のloss_idxのインデックスをとってきたら良いのかを示すのが、idx_rankである。\n",
    "        # 例えば、\n",
    "        # idx_rankの要素0番目 = idx_rank[0]を求めるには、loss_idxの値が0の要素、\n",
    "        # つまりloss_idx[?}=0 の、?は何番かを求めることになる。ここで、? = idx_rank[0]である。\n",
    "        # いま、loss_idx[?]=0の0は、元のloss_cの要素の0番目という意味である。\n",
    "        # つまり?は、元のloss_cの要素0番目は、降順に並び替えられたloss_idxの何番目ですか\n",
    "        # を求めていることになり、 結果、\n",
    "        # ? = idx_rank[0] はloss_cの要素0番目が、降順の何番目かを示すことになる。\n",
    "\n",
    "        # 背景のDBoxの数num_negを決める。HardNegative Miningにより、\n",
    "        # 物体発見のDBoxの数num_posの3倍（self.negpos_ratio倍）とする。\n",
    "        # ただし、万が一、DBoxの数を超える場合は、DBoxの数を上限とする\n",
    "        num_neg = torch.clamp(num_pos*self.negpos_ratio, max=num_dbox)\n",
    "\n",
    "        # idx_rankは各DBoxの損失の大きさが上から何番目なのかが入っている\n",
    "        # 背景のDBoxの数num_negよりも、順位が低い（すなわち損失が大きい）DBoxを取るマスク作成\n",
    "        # torch.Size([num_batch, 8732])\n",
    "        neg_mask = idx_rank < (num_neg).expand_as(idx_rank)\n",
    "\n",
    "        # -----------------\n",
    "        # （終了）これからNegative DBoxのうち、Hard Negative Miningで抽出するものを求めるマスクを作成します\n",
    "        # -----------------\n",
    "\n",
    "        # マスクの形を整形し、conf_dataに合わせる\n",
    "        # pos_idx_maskはPositive DBoxのconfを取り出すマスクです\n",
    "        # neg_idx_maskはHard Negative Miningで抽出したNegative DBoxのconfを取り出すマスクです\n",
    "        # pos_mask：torch.Size([num_batch, 8732])→pos_idx_mask：torch.Size([num_batch, 8732, 21])\n",
    "        pos_idx_mask = pos_mask.unsqueeze(2).expand_as(conf_data)\n",
    "        neg_idx_mask = neg_mask.unsqueeze(2).expand_as(conf_data)\n",
    "\n",
    "        # conf_dataからposとnegだけを取り出してconf_hnmにする。形はtorch.Size([num_pos+num_neg, 21])\n",
    "        conf_hnm = conf_data[(pos_idx_mask+neg_idx_mask).gt(0)\n",
    "                             ].view(-1, num_classes)\n",
    "        # （注釈）gtは greater than (>)の略称。これでmaskが1のindexを取り出す。\n",
    "        # pos_idx_mask+neg_idx_maskは足し算だが、indexへのmaskをまとめているだけである。\n",
    "        # つまり、posであろうがnegであろうが、マスクが1のものを足し算で一つのリストにし、それをgtで取得\n",
    "\n",
    "        # 同様に教師データであるconf_t_labelからposとnegだけを取り出してconf_t_label_hnmに\n",
    "        # 形はtorch.Size([pos+neg])になる\n",
    "        conf_t_label_hnm = conf_t_label[(pos_mask+neg_mask).gt(0)]\n",
    "\n",
    "        # confidenceの損失関数を計算（要素の合計=sumを求める）\n",
    "        loss_c = F.cross_entropy(conf_hnm, conf_t_label_hnm, reduction='sum')\n",
    "\n",
    "        # 物体を発見したBBoxの数N（全ミニバッチの合計）で損失を割り算\n",
    "        N = num_pos.sum()\n",
    "        loss_l /= N\n",
    "        loss_c /= N\n",
    "\n",
    "        return loss_l, loss_c\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "以上"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
